# Copyright 2023 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test")

# Package-wide defaults applied to every target declared in this file:
#  - targets are visible to all other packages,
#  - the `layering_check` feature is enabled,
#  - the license type is declared as "notice".
package(
    default_visibility = ["//visibility:public"],
    features = ["layering_check"],
    licenses = ["notice"],  # Apache 2.0
)

# Main library target of this package (`:cuda`).
#
# Only `api.h` is exported through `hdrs`; every other header is listed in
# `srcs` next to its `.c` file. Assuming `iree_runtime_cc_library` forwards
# these attributes to `cc_library`, headers in `srcs` are private to this
# target and cannot be included by dependents -- TODO confirm against the macro
# definition in build_defs.oss.bzl.
iree_runtime_cc_library(
    name = "cuda",
    srcs = [
        # NOTE(review): `api.h` is also listed in `hdrs` below, so this entry
        # looks redundant. Left as-is; confirm before removing.
        "api.h",
        "cuda_allocator.c",
        "cuda_allocator.h",
        "cuda_buffer.c",
        "cuda_buffer.h",
        "cuda_device.c",
        "cuda_device.h",
        # No `cuda_driver.h` is listed alongside this source file.
        "cuda_driver.c",
        "event_pool.c",
        "event_pool.h",
        "event_semaphore.c",
        "event_semaphore.h",
        "graph_command_buffer.c",
        "graph_command_buffer.h",
        "memory_pools.c",
        "memory_pools.h",
        "native_executable.c",
        "native_executable.h",
        "nccl_channel.c",
        "nccl_channel.h",
        "nop_executable_cache.c",
        "nop_executable_cache.h",
        "stream_command_buffer.c",
        "stream_command_buffer.h",
        "timepoint_pool.c",
        "timepoint_pool.h",
    ],
    # The only header exported to targets that depend on `:cuda`.
    hdrs = [
        "api.h",
    ],
    deps = [
        # Sibling target defined in this file. This target lists no direct
        # dependency on `@iree_cuda//:headers` or `@nccl//:headers`; only
        # `:dynamic_symbols` does.
        ":dynamic_symbols",
        "//runtime/src/iree/base",
        "//runtime/src/iree/base:core_headers",
        "//runtime/src/iree/base/internal",
        "//runtime/src/iree/base/internal:arena",
        "//runtime/src/iree/base/internal:atomic_slist",
        "//runtime/src/iree/base/internal:event_pool",
        "//runtime/src/iree/base/internal:synchronization",
        "//runtime/src/iree/base/internal:threading",
        "//runtime/src/iree/base/internal:wait_handle",
        "//runtime/src/iree/base/internal/flatcc:parsing",
        "//runtime/src/iree/hal",
        "//runtime/src/iree/hal/utils:collective_batch",
        "//runtime/src/iree/hal/utils:deferred_command_buffer",
        "//runtime/src/iree/hal/utils:deferred_work_queue",
        "//runtime/src/iree/hal/utils:executable_debug_info",
        "//runtime/src/iree/hal/utils:file_transfer",
        "//runtime/src/iree/hal/utils:files",
        "//runtime/src/iree/hal/utils:queue_emulation",
        "//runtime/src/iree/hal/utils:queue_host_call_emulation",
        "//runtime/src/iree/hal/utils:resource_set",
        "//runtime/src/iree/hal/utils:semaphore_base",
        "//runtime/src/iree/hal/utils:stream_tracing",
        "//runtime/src/iree/schemas:cuda_executable_def_c_fbs",
        "//runtime/src/iree/schemas:executable_debug_info_c_fbs",
    ],
)

# Library target `:dynamic_symbols`, covering the CUDA and NCCL dynamic symbol
# and status utility sources. It is the only target in this file that depends
# on the external `@iree_cuda//:headers` and `@nccl//:headers` repositories.
iree_runtime_cc_library(
    name = "dynamic_symbols",
    srcs = [
        "cuda_dynamic_symbols.c",
        # `cuda_headers.h` and `nccl_headers.h` are listed in `srcs`, not
        # `hdrs`, so they are not exported as public headers of this target.
        "cuda_headers.h",
        "cuda_status_util.c",
        "nccl_dynamic_symbols.c",
        "nccl_headers.h",
        "nccl_status_util.c",
    ],
    # Headers exported to dependents of this target.
    hdrs = [
        "cuda_dynamic_symbols.h",
        "cuda_status_util.h",
        "nccl_dynamic_symbols.h",
        "nccl_status_util.h",
    ],
    # Bazel's `cc_library` documents `textual_hdrs` as headers that cannot be
    # compiled on their own. Presumably these symbol tables are meant to be
    # textually included by the sources above -- confirm against the headers.
    textual_hdrs = [
        "cuda_dynamic_symbol_table.h",
        "nccl_dynamic_symbol_table.h",
    ],
    deps = [
        "//runtime/src/iree/base",
        "//runtime/src/iree/base/internal:dynamic_library",
        "@iree_cuda//:headers",
        "@nccl//:headers",
    ],
)

# Test target built from `dynamic_symbols_test.cc`. It depends on the
# `:dynamic_symbols` library above and on the gtest targets under
# `//runtime/src/iree/testing`.
iree_runtime_cc_test(
    name = "dynamic_symbols_test",
    srcs = [
        "dynamic_symbols_test.cc",
    ],
    # NOTE(review): presumably this tag lets test infrastructure select or skip
    # tests that need the CUDA driver -- confirm how `driver=cuda` is consumed.
    tags = ["driver=cuda"],
    deps = [
        ":dynamic_symbols",
        "//runtime/src/iree/base",
        "//runtime/src/iree/testing:gtest",
        "//runtime/src/iree/testing:gtest_main",
    ],
)
